{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Adam算法\n",
    ":label:`sec_adam`\n",
    "\n",
    "本章我们已经学习了许多有效优化的技术。\n",
    "在本节讨论之前，我们先详细回顾一下这些技术：\n",
    "\n",
    "* 在 :numref:`sec_sgd`中，我们学习了：随机梯度下降在解决优化问题时比梯度下降更有效。\n",
    "* 在 :numref:`sec_minibatch_sgd`中，我们学习了：在一个小批量中使用更大的观测值集，可以通过向量化提供额外效率。这是高效的多机、多GPU和整体并行处理的关键。\n",
    "* 在 :numref:`sec_momentum`中我们添加了一种机制，用于汇总过去梯度的历史以加速收敛。\n",
    "* 在 :numref:`sec_adagrad`中，我们使用每坐标缩放来实现计算效率的预处理。\n",
    "* 在 :numref:`sec_rmsprop`中，我们通过学习率的调整来分离每个坐标的缩放。\n",
    "\n",
    "Adam算法 :cite:`Kingma.Ba.2014`将所有这些技术汇总到一个高效的学习算法中。\n",
    "不出预料，作为深度学习中使用的更强大和有效的优化算法之一，它非常受欢迎。\n",
    "但是它并非没有问题，尤其是 :cite:`Reddi.Kale.Kumar.2019`表明，有时Adam算法可能由于方差控制不良而发散。\n",
    "在完善工作中， :cite:`Zaheer.Reddi.Sachan.ea.2018`给Adam算法提供了一个称为Yogi的热补丁来解决这些问题。\n",
    "下面我们了解一下Adam算法。\n",
    "\n",
    "## 算法\n",
    "\n",
    "Adam算法的关键组成部分之一是：它使用指数加权移动平均值来估算梯度的动量和第二力矩，即它使用状态变量\n",
    "\n",
    "$$\\begin{aligned}\n",
    "    \\mathbf{v}_t & \\leftarrow \\beta_1 \\mathbf{v}_{t-1} + (1 - \\beta_1) \\mathbf{g}_t, \\\\\n",
    "    \\mathbf{s}_t & \\leftarrow \\beta_2 \\mathbf{s}_{t-1} + (1 - \\beta_2) \\mathbf{g}_t^2.\n",
    "\\end{aligned}$$\n",
    "\n",
    "这里$\\beta_1$和$\\beta_2$是非负加权参数。\n",
    "他们的常见设置是$\\beta_1 = 0.9$和$\\beta_2 = 0.999$。\n",
    "也就是说，方差的估计比动量的估计移动得远远更慢。\n",
    "注意，如果我们初始化$\\mathbf{v}_0 = \\mathbf{s}_0 = 0$，就会获得一个相当大的初始偏差。\n",
    "我们可以通过使用$\\sum_{i=0}^t \\beta^i = \\frac{1 - \\beta^t}{1 - \\beta}$来解决这个问题。\n",
    "相应地，标准化状态变量由以下获得\n",
    "\n",
    "$$\\hat{\\mathbf{v}}_t = \\frac{\\mathbf{v}_t}{1 - \\beta_1^t} \\text{ and } \\hat{\\mathbf{s}}_t = \\frac{\\mathbf{s}_t}{1 - \\beta_2^t}.$$\n",
    "\n",
    "有了正确的估计，我们现在可以写出更新方程。\n",
    "首先，我们以非常类似于RMSProp算法的方式重新缩放梯度以获得\n",
    "\n",
    "$$\\mathbf{g}_t' = \\frac{\\eta \\hat{\\mathbf{v}}_t}{\\sqrt{\\hat{\\mathbf{s}}_t} + \\epsilon}.$$\n",
    "\n",
    "与RMSProp不同，我们的更新使用动量$\\hat{\\mathbf{v}}_t$而不是梯度本身。\n",
    "此外，由于使用$\\frac{1}{\\sqrt{\\hat{\\mathbf{s}}_t} + \\epsilon}$而不是$\\frac{1}{\\sqrt{\\hat{\\mathbf{s}}_t + \\epsilon}}$进行缩放，两者会略有差异。\n",
    "前者在实践中效果略好一些，因此与RMSProp算法有所区分。\n",
    "通常，我们选择$\\epsilon = 10^{-6}$，这是为了在数值稳定性和逼真度之间取得良好的平衡。\n",
    "\n",
    "最后，我们简单更新：\n",
    "\n",
    "$$\\mathbf{x}_t \\leftarrow \\mathbf{x}_{t-1} - \\mathbf{g}_t'.$$\n",
    "\n",
    "回顾Adam算法，它的设计灵感很清楚：\n",
    "首先，动量和规模在状态变量中清晰可见，\n",
    "他们相当独特的定义使我们移除偏项目（这可以通过稍微不同的初始化和更新条件来修正）。\n",
    "其次，RMSProp算法中两个项目的组合都非常简单。\n",
    "最后，明确的学习率$\\eta$使我们能够控制步长来解决收敛问题。\n",
    "\n",
    "## 实现\n",
    "\n",
    "从头开始实现Adam算法并不难。\n",
    "为方便起见，我们将时间步$t$存储在`hyperparams`字典中。\n",
    "除此之外，一切都很简单。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load ../utils/djl-imports\n",
    "%load ../utils/plot-utils\n",
    "%load ../utils/Functions.java\n",
    "%load ../utils/GradDescUtils.java\n",
    "%load ../utils/Accumulator.java\n",
    "%load ../utils/StopWatch.java\n",
    "%load ../utils/Training.java\n",
    "%load ../utils/TrainingChapter11.java"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NDList initAdamStates(int featureDimension) {\n",
    "    NDManager manager = NDManager.newBaseManager();\n",
    "    NDArray vW = manager.zeros(new Shape(featureDimension, 1));\n",
    "    NDArray vB = manager.zeros(new Shape(1));\n",
    "    NDArray sW = manager.zeros(new Shape(featureDimension, 1));\n",
    "    NDArray sB = manager.zeros(new Shape(1));\n",
    "    return new NDList(vW, sW, vB, sB);\n",
    "}\n",
    "\n",
    "public class Optimization {\n",
    "    public static void adam(NDList params, NDList states, Map<String, Float> hyperparams) {\n",
    "        float beta1 = 0.9f;\n",
    "        float beta2 = 0.999f;\n",
    "        float eps = (float) 1e-6;\n",
    "        float time = hyperparams.get(\"time\");\n",
    "        for (int i = 0; i < params.size(); i++) {\n",
    "            NDArray param = params.get(i);\n",
    "            NDArray velocity = states.get(2 * i);\n",
    "            NDArray state = states.get(2 * i + 1);\n",
    "            // Update parameter, velocity, and state\n",
    "            // velocity = beta1 * v + (1 - beta1) * param.gradient\n",
    "            velocity.muli(beta1).addi(param.getGradient().mul(1 - beta1));\n",
    "            // state = beta2 * state + (1 - beta2) * param.gradient^2\n",
    "            state.muli(beta2).addi(param.getGradient().square().mul(1 - beta2));\n",
    "            // vBiasCorr = velocity / ((1 - beta1)^(time))\n",
    "            NDArray vBiasCorr = velocity.div(1 - Math.pow(beta1, time));\n",
    "            // sBiasCorr = state / ((1 - beta2)^(time))\n",
    "            NDArray sBiasCorr = state.div(1 - Math.pow(beta2, time));\n",
    "            // param -= lr * vBiasCorr / (sBiasCorr^(1/2) + eps)\n",
    "            param.subi(vBiasCorr.mul(hyperparams.get(\"lr\")).div(sBiasCorr.sqrt().add(eps)));\n",
    "        }\n",
    "        hyperparams.put(\"time\", time + 1);\n",
    "    }\n",
    "}"
   ]
  },
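  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before training a full model, we can sanity check the scratch implementation above on a toy parameter (a minimal illustration; the toy variables below are not used later): for the loss $\\sum_i w_i^2$ the gradient is $2\\mathbf{w}$, so a single `adam` step should move every entry of $\\mathbf{w}$ slightly towards zero.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Toy example: one 2x1 parameter, loss = sum(w^2), hence gradient = 2 * w\n",
    "NDManager toyManager = NDManager.newBaseManager();\n",
    "NDArray w = toyManager.create(new float[]{1f, -2f}, new Shape(2, 1));\n",
    "w.setRequiresGradient(true);\n",
    "\n",
    "try (GradientCollector gc = Engine.getInstance().newGradientCollector()) {\n",
    "    NDArray loss = w.mul(w).sum();\n",
    "    gc.backward(loss);\n",
    "}\n",
    "\n",
    "// One (velocity, second moment) pair, matching the state layout expected by adam()\n",
    "NDList toyParams = new NDList(w);\n",
    "NDList toyStates = new NDList(toyManager.zeros(new Shape(2, 1)),\n",
    "                              toyManager.zeros(new Shape(2, 1)));\n",
    "Map<String, Float> toyHyperparams = new HashMap<>();\n",
    "toyHyperparams.put(\"lr\", 0.01f);\n",
    "toyHyperparams.put(\"time\", 1f);\n",
    "\n",
    "Optimization.adam(toyParams, toyStates, toyHyperparams);\n",
    "// After one step w moves from [1, -2] roughly to [0.99, -1.99]\n",
    "System.out.println(w);"
   ]
  },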
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "现在，我们用以上Adam算法来训练模型，这里我们使用$\\eta = 0.01$的学习率。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "AirfoilRandomAccess airfoil = TrainingChapter11.getDataCh11(10, 1500);\n",
    "\n",
    "public TrainingChapter11.LossTime trainAdam(float lr, float time, int numEpochs) throws IOException, TranslateException {\n",
    "    int featureDimension = airfoil.getColumnNames().size();\n",
    "    Map<String, Float> hyperparams = new HashMap<>();\n",
    "    hyperparams.put(\"lr\", lr);\n",
    "    hyperparams.put(\"time\", time);\n",
    "    return TrainingChapter11.trainCh11(Optimization::adam, \n",
    "                                       initAdamStates(featureDimension), \n",
    "                                       hyperparams, airfoil, \n",
    "                                       featureDimension, numEpochs);\n",
    "}\n",
    "\n",
    "TrainingChapter11.LossTime lossTime = trainAdam(0.01f, 1, 2);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "此外，我们可以用深度学习框架自带算法应用Adam算法，这里我们只需要传递配置参数。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Tracker lrt = Tracker.fixed(0.01f);\n",
    "Optimizer adam = Optimizer.adam().optLearningRateTracker(lrt).build();\n",
    "\n",
    "TrainingChapter11.trainConciseCh11(adam, airfoil, 2);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Yogi\n",
    "\n",
    "Adam算法也存在一些问题：\n",
    "即使在凸环境下，当$\\mathbf{s}_t$的第二力矩估计值爆炸时，它可能无法收敛。\n",
    " :cite:`Zaheer.Reddi.Sachan.ea.2018`为$\\mathbf{s}_t$提出了的改进更新和参数初始化。\n",
    "论文中建议我们重写Adam算法更新如下：\n",
    "\n",
    "$$\\mathbf{s}_t \\leftarrow \\mathbf{s}_{t-1} + (1 - \\beta_2) \\left(\\mathbf{g}_t^2 - \\mathbf{s}_{t-1}\\right).$$\n",
    "\n",
    "每当$\\mathbf{g}_t^2$具有高变量或更新稀疏时，$\\mathbf{s}_t$可能会太快地“忘记”过去的值。\n",
    "一个有效的解决方法是将$\\mathbf{g}_t^2 - \\mathbf{s}_{t-1}$替换为$\\mathbf{g}_t^2 \\odot \\mathop{\\mathrm{sgn}}(\\mathbf{g}_t^2 - \\mathbf{s}_{t-1})$。\n",
    "这就是Yogi更新，现在更新的规模不再取决于偏差的量。\n",
    "\n",
    "$$\\mathbf{s}_t \\leftarrow \\mathbf{s}_{t-1} + (1 - \\beta_2) \\mathbf{g}_t^2 \\odot \\mathop{\\mathrm{sgn}}(\\mathbf{g}_t^2 - \\mathbf{s}_{t-1}).$$\n",
    "\n",
    "论文中，作者还进一步建议用更大的初始批量来初始化动量，而不仅仅是初始的逐点估计。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public class Optimization {\n",
    "    public static void yogi(NDList params, NDList states, Map<String, Float> hyperparams) {\n",
    "        float beta1 = 0.9f;\n",
    "        float beta2 = 0.999f;\n",
    "        float eps = (float) 1e-3;\n",
    "        float time = hyperparams.get(\"time\");\n",
    "        for (int i = 0; i < params.size(); i++) {\n",
    "            NDArray param = params.get(i);\n",
    "            NDArray velocity = states.get(2 * i);\n",
    "            NDArray state = states.get(2 * i + 1);\n",
    "            // Update parameter, velocity, and state\n",
    "            // velocity = beta1 * v + (1 - beta1) * param.gradient\n",
    "            velocity.muli(beta1).addi(param.getGradient().mul(1 - beta1));\n",
    "            /* Rewritten Update */\n",
    "            // state = state + (1 - beta2) * sign(param.gradient^2 - state) \n",
    "            //         * param.gradient^2\n",
    "            state.addi(param.getGradient().square().sub(state).sign().mul(1 - beta2));\n",
    "            // vBiasCorr = velocity / ((1 - beta1)^(time))\n",
    "            NDArray vBiasCorr = velocity.div(1 - Math.pow(beta1, time));\n",
    "            // sBiasCorr = state / ((1 - beta2)^(time))\n",
    "            NDArray sBiasCorr = state.div(1 - Math.pow(beta2, time));\n",
    "            // param -= lr * vBiasCorr / (sBiasCorr^(1/2) + eps)\n",
    "            param.subi(vBiasCorr.mul(hyperparams.get(\"lr\")).div(sBiasCorr.sqrt().add(eps)));\n",
    "        }\n",
    "        hyperparams.put(\"time\", time + 1);\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "AirfoilRandomAccess airfoil = TrainingChapter11.getDataCh11(10, 1500);\n",
    "\n",
    "public TrainingChapter11.LossTime trainYogi(float lr, float time, int numEpochs) throws IOException, TranslateException {\n",
    "    int featureDimension = airfoil.getColumnNames().size();\n",
    "    Map<String, Float> hyperparams = new HashMap<>();\n",
    "    hyperparams.put(\"lr\", lr);\n",
    "    hyperparams.put(\"time\", time);\n",
    "    return TrainingChapter11.trainCh11(Optimization::yogi, \n",
    "                                       initAdamStates(featureDimension), \n",
    "                                       hyperparams, airfoil, \n",
    "                                       featureDimension, numEpochs);\n",
    "}\n",
    "\n",
    "TrainingChapter11.LossTime lossTime = trainYogi(0.01f, 1, 2);\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 小结\n",
    "\n",
    "* Adam算法将许多优化算法的功能结合到了相当强大的更新规则中。\n",
    "* Adam算法在RMSProp算法基础上创建的，还在小批量的随机梯度上使用EWMA。\n",
    "* 在估计动量和第二力矩时，Adam算法使用偏差校正来调整缓慢的启动速度。\n",
    "* 对于具有显著差异的梯度，我们可能会遇到收敛性问题。我们可以通过使用更大的小批量或者切换到改进的估计值$\\mathbf{s}_t$来修正它们。Yogi提供了这样的替代方案。\n",
    "\n",
    "## 练习\n",
    "\n",
    "1. 调节学习率，观察并分析实验结果。\n",
    "1. 你能重写动量和第二力矩更新，从而使其不需要偏差校正吗？\n",
    "1. 当我们收敛时，为什么你需要降低学习率$\\eta$？\n",
    "1. 尝试构造一个使用Adam算法会发散而Yogi会收敛的例子。\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "14.0.2+12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
